# Load the required packages.
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn import metrics
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
import tensorflow as tf
from keras import backend as K
from tensorflow import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, LeakyReLU, BatchNormalization
from keras.metrics import AUC
# Load the dataset.
## Read the patient records from 'COVID_data.csv' (path relative to the working directory)
## and print a per-column summary of non-null counts and dtypes.
COVID_data = pd.read_csv('COVID_data.csv')
COVID_data.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 10020 entries, 0 to 10019 Data columns (total 16 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 Age.at.diagnosis 10020 non-null object 1 Sex 10020 non-null object 2 Month.first.diagnosis 10020 non-null object 3 Year.first.diagnosis 10020 non-null int64 4 Uncomplicated.phase 10020 non-null object 5 Complicated.phase 10020 non-null object 6 Critical.phase 10020 non-null object 7 Recovery.phase 10020 non-null object 8 Vasopressors.in.complicated.phase 4725 non-null object 9 Vasopressors.in.critical.phase 1856 non-null object 10 Invasive.ventilation.in.critical.phase 1856 non-null object 11 Superinfection.in.uncomplicated.phase 8064 non-null object 12 Superinfection.in.complicated.phase 4725 non-null object 13 Superinfection.in.critical.phase 1856 non-null object 14 Symptoms.in.recovery.phase 4941 non-null object 15 Last.known.patient.status 10020 non-null object dtypes: int64(1), object(15) memory usage: 1.2+ MB
# Fill any blanks in the data with n/a.
## Every missing cell becomes the literal string "n/a", so blanks are an explicit category
## that later cells can filter on or recode. Then preview the first 20 rows.
COVID_data = COVID_data.fillna("n/a")
COVID_data.head(20)
| Age.at.diagnosis | Sex | Month.first.diagnosis | Year.first.diagnosis | Uncomplicated.phase | Complicated.phase | Critical.phase | Recovery.phase | Vasopressors.in.complicated.phase | Vasopressors.in.critical.phase | Invasive.ventilation.in.critical.phase | Superinfection.in.uncomplicated.phase | Superinfection.in.complicated.phase | Superinfection.in.critical.phase | Symptoms.in.recovery.phase | Last.known.patient.status | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 26 - 45 years | Female | <= 3 | 2020 | yes | no | no | no | n/a | n/a | n/a | none | n/a | n/a | n/a | Recovered |
| 1 | 46 - 65 years | Female | <= 3 | 2020 | yes | yes | no | yes | no | n/a | n/a | none | none | n/a | yes | Recovered |
| 2 | <= 25 years | Female | <= 3 | 2020 | yes | no | no | no | n/a | n/a | n/a | none | n/a | n/a | n/a | Recovered |
| 3 | 46 - 65 years | Male | <= 3 | 2020 | yes | no | no | yes | n/a | n/a | n/a | none | n/a | n/a | no | Recovered |
| 4 | 26 - 45 years | Female | <= 3 | 2020 | yes | no | no | no | n/a | n/a | n/a | none | n/a | n/a | n/a | Recovered |
| 5 | 46 - 65 years | Female | <= 3 | 2020 | yes | no | no | no | n/a | n/a | n/a | none | n/a | n/a | n/a | Recovered |
| 6 | 46 - 65 years | Male | <= 3 | 2020 | yes | no | no | no | n/a | n/a | n/a | unknown/missing | n/a | n/a | n/a | Not recovered |
| 7 | 46 - 65 years | Male | <= 3 | 2020 | yes | no | no | no | n/a | n/a | n/a | none | n/a | n/a | n/a | Recovered |
| 8 | 46 - 65 years | Male | <= 3 | 2020 | yes | no | no | no | n/a | n/a | n/a | unknown/missing | n/a | n/a | n/a | Not recovered |
| 9 | 46 - 65 years | Female | <= 3 | 2020 | yes | no | no | no | n/a | n/a | n/a | unknown/missing | n/a | n/a | n/a | Not recovered |
| 10 | 46 - 65 years | Female | <= 3 | 2020 | yes | yes | no | no | no | n/a | n/a | none | none | n/a | n/a | Recovered |
| 11 | 46 - 65 years | Male | <= 3 | 2020 | yes | no | no | no | n/a | n/a | n/a | none | n/a | n/a | n/a | Recovered |
| 12 | 26 - 45 years | Male | <= 3 | 2020 | yes | no | no | no | n/a | n/a | n/a | none | n/a | n/a | n/a | Recovered |
| 13 | 26 - 45 years | Female | <= 3 | 2020 | yes | no | no | no | n/a | n/a | n/a | none | n/a | n/a | n/a | Recovered |
| 14 | 26 - 45 years | Male | <= 3 | 2020 | yes | no | no | no | n/a | n/a | n/a | none | n/a | n/a | n/a | Recovered |
| 15 | 26 - 45 years | Male | <= 3 | 2020 | yes | no | no | no | n/a | n/a | n/a | none | n/a | n/a | n/a | unknown/missing |
| 16 | 26 - 45 years | Male | <= 3 | 2020 | yes | no | no | no | n/a | n/a | n/a | unknown/missing | n/a | n/a | n/a | Recovered |
| 17 | 46 - 65 years | Male | <= 3 | 2020 | yes | no | yes | no | n/a | yes | yes | none | n/a | bacterial | n/a | Recovered |
| 18 | 46 - 65 years | Male | <= 3 | 2020 | yes | yes | yes | no | yes | yes | yes | none | none | none | n/a | Recovered |
| 19 | 46 - 65 years | Male | <= 3 | 2020 | yes | no | yes | no | n/a | yes | yes | none | n/a | none | n/a | Recovered |
# Use count plots to visualize the distribution of the data.
sns.set(rc={'figure.figsize':(11.7,8.27)})


def _show_countplot(column, data, order=None):
    """Draw and display a count plot for a single column.

    column: name of the column whose categories are counted.
    data: DataFrame the column is read from.
    order: optional list of category labels fixing which bars appear and in
        what order; None leaves the ordering to seaborn.
    """
    sns.countplot(x=column, data=data, order=order)
    plt.show()


# Category orderings shared by several plots.
yes_no_order = ['no', 'yes', 'unknown/missing']
superinfection_order = ['none', 'bacterial', 'bacterial&fungal', 'fungal', 'unknown/missing']

# Copy of the data with every "n/a" cell blanked out (NaN), so placeholder
# values are not counted. Built once and reused; `~` is the boolean negation
# operator for masks.
answered_data = COVID_data[~(COVID_data == 'n/a')]

_show_countplot("Age.at.diagnosis", COVID_data,
                order=['<= 25 years', '26 - 45 years', '46 - 65 years', '66 - 85 years', '> 85 years'])
# Most of the individuals in this survey are 46-85 years old.
## Very few are under 25, and very few are over 85.
_show_countplot("Sex", COVID_data)
# There are more males represented in this dataset than females.
_show_countplot("Month.first.diagnosis", COVID_data,
                order=['1', '2', '<= 3', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12'])
# Most patients were first diagnosed 4 or fewer months ago.
## A substantial portion were first diagnosed 10-12 months ago.
_show_countplot("Year.first.diagnosis", COVID_data)
# Most individuals in this survey were first diagnosed in 2020.
_show_countplot("Uncomplicated.phase", COVID_data, order=['no', 'yes'])
# The majority of individuals infected with COVID-19 experience the uncomplicated phase.
_show_countplot("Complicated.phase", COVID_data)
# A substantial portion of individuals experience the complicated phase.
_show_countplot("Critical.phase", COVID_data)
# The majority of individuals do not enter the critical phase.
_show_countplot("Recovery.phase", COVID_data)
# A substantial portion of individuals enter the recovery phase.
## NOTE: some patients may recover without being listed as having entered the recovery phase.
## Read more about how the four different phases are diagnosed here: https://leoss.net/statistics/.
_show_countplot("Vasopressors.in.complicated.phase", answered_data, order=yes_no_order)
_show_countplot("Vasopressors.in.critical.phase", answered_data, order=yes_no_order)
# Far more patients required vasopressors in the critical phase than in the complicated phase.
## Vasopressors are used to constrict blood vessels for people with low blood pressure.
_show_countplot("Invasive.ventilation.in.critical.phase", answered_data, order=yes_no_order)
# A substantial portion of individuals required invasive ventilation in the critical phase.
## Invasive ventilation involves inserting a tube in the throat to take over respiratory function.
_show_countplot("Superinfection.in.uncomplicated.phase", answered_data, order=superinfection_order)
_show_countplot("Superinfection.in.complicated.phase", answered_data, order=superinfection_order)
_show_countplot("Superinfection.in.critical.phase", answered_data, order=superinfection_order)
# A much larger portion of individuals in the critical phase experience superinfections.
## Bacterial superinfections tend to be the most common.
_show_countplot("Symptoms.in.recovery.phase", answered_data, order=yes_no_order)
# Most patients did not experience symptoms in the recovery phase.
## For the outcome plot, hide the "unknown/missing" statuses instead of the "n/a" cells.
_show_countplot("Last.known.patient.status", COVID_data[~(COVID_data == 'unknown/missing')])
# The majority of the patients in this dataset recovered from COVID-19.
## However, I am interested in predicting the chance that someone will die from COVID-19 based on this data.
## I will therefore be looking at those categorized as "Dead from COVID-19".
# Make all variables numeric, on a scale of 0 to 1.
## Intermediate values should be evenly spaced from 0 to 1.
## I will be cleaning and scaling the data manually due to some inconsistencies in it.
## Age bands become five evenly spaced codes, youngest (0) to oldest (1).
age_codes = {
    '<= 25 years': 0,
    '26 - 45 years': .25,
    '46 - 65 years': .5,
    '66 - 85 years': .75,
    '> 85 years': 1,
}
for age_band, age_code in age_codes.items():
    COVID_data.loc[COVID_data['Age.at.diagnosis'] == age_band, 'Age.at.diagnosis'] = age_code
## Sex is binary: Female = 0, Male = 1.
for sex_label, sex_code in (('Female', 0), ('Male', 1)):
    COVID_data.loc[COVID_data['Sex'] == sex_label, 'Sex'] = sex_code
# I will be combining "<= 3" with "2" for the purposes of this model.
is_early_month = COVID_data['Month.first.diagnosis'] == '<= 3'
COVID_data.loc[is_early_month, 'Month.first.diagnosis'] = 0.0909
## Display only: the float view is not stored, so the remaining month labels
## are still strings when they are recoded afterwards.
COVID_data['Month.first.diagnosis'].astype(float)
0 0.0909
1 0.0909
2 0.0909
3 0.0909
4 0.0909
...
10015 8.0000
10016 8.0000
10017 8.0000
10018 7.0000
10019 12.0000
Name: Month.first.diagnosis, Length: 10020, dtype: float64
def _encode_column(frame, column, codes):
    """Replace the labels in one column of ``frame`` with numeric codes, in place.

    frame: DataFrame to modify.
    column: name of the column to recode.
    codes: dict mapping each original label to its numeric replacement;
        labels that are not listed are left untouched.
    """
    for label, code in codes.items():
        frame.loc[frame[column] == label, column] = code


# Months '1'..'12' are spread over [0, 1] in steps of 1/11, truncated to four
# decimals. '2' must stay at .0909 so it matches the value already given to "<= 3".
month_codes = {
    '1': 0, '2': .0909, '3': .1818, '4': .2727, '5': .3636, '6': .4545,
    '7': .5454, '8': .6363, '9': .7272, '10': .8181, '11': 0.909, '12': 1,
}
_encode_column(COVID_data, 'Month.first.diagnosis', month_codes)
_encode_column(COVID_data, 'Year.first.diagnosis', {2020: 0, 2021: 1})

# "N/A"s and "No"s are 0s, "Yes"s are 1s.
phase_codes = {'n/a': 0, 'no': 0, 'yes': 1}
# Recovery.phase now checks its own column for 'n/a'; it previously tested
# Critical.phase, so that case was never applied to Recovery.phase.
for phase_column in ('Uncomplicated.phase', 'Complicated.phase',
                     'Critical.phase', 'Recovery.phase'):
    _encode_column(COVID_data, phase_column, phase_codes)

# For the yes/no follow-up columns, "unknown/missing" is also treated as 0.
yes_no_codes = {'n/a': 0, 'unknown/missing': 0, 'no': 0, 'yes': 1}
for yes_no_column in ('Vasopressors.in.complicated.phase',
                      'Vasopressors.in.critical.phase',
                      'Invasive.ventilation.in.critical.phase',
                      'Symptoms.in.recovery.phase'):
    _encode_column(COVID_data, yes_no_column, yes_no_codes)

# All forms of superinfections are 1s.
superinfection_codes = {
    'n/a': 0, 'unknown/missing': 0, 'none': 0,
    'bacterial': 1, 'bacterial&fungal': 1, 'fungal': 1,
}
for superinfection_column in ('Superinfection.in.uncomplicated.phase',
                              'Superinfection.in.complicated.phase',
                              'Superinfection.in.critical.phase'):
    _encode_column(COVID_data, superinfection_column, superinfection_codes)

# Only "Dead from COVID-19" is a 1, all other instances are 0s.
status_codes = {
    'Dead from COVID-19': 1,
    'Dead from other causes': 0,
    'Not recovered': 0,
    'Recovered': 0,
    'unknown/missing': 0,
}
_encode_column(COVID_data, 'Last.known.patient.status', status_codes)
COVID_data.head(20)
| Age.at.diagnosis | Sex | Month.first.diagnosis | Year.first.diagnosis | Uncomplicated.phase | Complicated.phase | Critical.phase | Recovery.phase | Vasopressors.in.complicated.phase | Vasopressors.in.critical.phase | Invasive.ventilation.in.critical.phase | Superinfection.in.uncomplicated.phase | Superinfection.in.complicated.phase | Superinfection.in.critical.phase | Symptoms.in.recovery.phase | Last.known.patient.status | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.25 | 0 | 0.0909 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 1 | 0.5 | 0 | 0.0909 | 0 | 1 | 1 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 |
| 2 | 0 | 0 | 0.0909 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 3 | 0.5 | 1 | 0.0909 | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 4 | 0.25 | 0 | 0.0909 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 5 | 0.5 | 0 | 0.0909 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 6 | 0.5 | 1 | 0.0909 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 7 | 0.5 | 1 | 0.0909 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 8 | 0.5 | 1 | 0.0909 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 9 | 0.5 | 0 | 0.0909 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 10 | 0.5 | 0 | 0.0909 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 11 | 0.5 | 1 | 0.0909 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 12 | 0.25 | 1 | 0.0909 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 13 | 0.25 | 0 | 0.0909 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 14 | 0.25 | 1 | 0.0909 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 15 | 0.25 | 1 | 0.0909 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 16 | 0.25 | 1 | 0.0909 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 17 | 0.5 | 1 | 0.0909 | 0 | 1 | 0 | 1 | 0 | 0 | 1 | 1 | 0 | 0 | 1 | 0 | 0 |
| 18 | 0.5 | 1 | 0.0909 | 0 | 1 | 1 | 1 | 0 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 |
| 19 | 0.5 | 1 | 0.0909 | 0 | 1 | 0 | 1 | 0 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 |
# Create an alternate dataset of just patients who have entered the complicated phase.
## Use this to determine whether the model is generalizable or not.
entered_complicated = COVID_data['Complicated.phase'] == 1
complicated_data = COVID_data[entered_complicated]
# Use the full COVID-19 dataset.
## The two recovery columns are removed from the features.
## NOTE(review): presumably because they reveal the outcome - confirm.
recovery_columns = ['Recovery.phase', 'Symptoms.in.recovery.phase']
COVID_data = COVID_data.drop(columns=recovery_columns)
# Create x and y variables, and train and test sets.
## pop() removes the target from COVID_data itself, so x (the same object) holds only features.
y = COVID_data.pop('Last.known.patient.status').to_frame()
x = COVID_data
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.25, random_state=42)
x_train.head()
| Age.at.diagnosis | Sex | Month.first.diagnosis | Year.first.diagnosis | Uncomplicated.phase | Complicated.phase | Critical.phase | Vasopressors.in.complicated.phase | Vasopressors.in.critical.phase | Invasive.ventilation.in.critical.phase | Superinfection.in.uncomplicated.phase | Superinfection.in.complicated.phase | Superinfection.in.critical.phase | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 8998 | 0.5 | 0 | 1 | 0 | 1 | 1 | 1 | 0 | 1 | 1 | 0 | 1 | 1 |
| 9237 | 0.75 | 1 | 0.8181 | 0 | 0 | 1 | 1 | 0 | 1 | 1 | 0 | 0 | 0 |
| 3265 | 0.75 | 1 | 0.2727 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 |
| 9349 | 0.75 | 1 | 0.1818 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 5665 | 0.5 | 0 | 1 | 0 | 0 | 1 | 1 | 0 | 0 | 1 | 0 | 0 | 0 |
# Preview the first training labels (1 = "Dead from COVID-19", 0 = any other status).
y_train.head()
| Last.known.patient.status | |
|---|---|
| 8998 | 1 |
| 9237 | 1 |
| 3265 | 0 |
| 9349 | 0 |
| 5665 | 1 |
# One-hot encode the 0/1 labels into two columns (column 1 = "Dead from COVID-19").
n_classes = 2
y_train = tf.keras.utils.to_categorical(y_train.values, num_classes=n_classes)
y_test = tf.keras.utils.to_categorical(y_test.values, num_classes=n_classes)
# Standardize the features using statistics learned from the training rows only.
s = StandardScaler()
s.fit(x_train)
x_train = s.transform(x_train)
x_test = s.transform(x_test)
# Start from a clean Keras state and fix the TensorFlow seed.
K.clear_session()
tf.random.set_seed(42)
# Fully connected network: 100 -> 50 -> 10 -> 5 -> 2 units.
n_features = x_train.shape[1]
model_class = Sequential()
model_class.add(Dense(100, activation='relu', input_shape=(n_features,)))
model_class.add(Dropout(0.1))
# The two middle layers are each followed by batch normalization and 10% dropout.
for hidden_units in (50, 10):
    model_class.add(Dense(hidden_units, activation='relu'))
    model_class.add(BatchNormalization())
    model_class.add(Dropout(0.1))
model_class.add(Dense(5, activation='relu'))
# Two sigmoid outputs, one per one-hot label column, scored with binary cross-entropy.
model_class.add(Dense(2, activation='sigmoid'))
model_class.compile(loss='binary_crossentropy', optimizer='adamax', metrics=[AUC()])
model_class.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 100) 1400
dropout (Dropout) (None, 100) 0
dense_1 (Dense) (None, 50) 5050
batch_normalization (BatchN (None, 50) 200
ormalization)
dropout_1 (Dropout) (None, 50) 0
dense_2 (Dense) (None, 10) 510
batch_normalization_1 (Batc (None, 10) 40
hNormalization)
dropout_2 (Dropout) (None, 10) 0
dense_3 (Dense) (None, 5) 55
dense_4 (Dense) (None, 2) 12
=================================================================
Total params: 7,267
Trainable params: 7,147
Non-trainable params: 120
_________________________________________________________________
# Train for 50 epochs with a validation fraction of 0.1 taken from the training data;
# the returned History object keeps the per-epoch loss and AUC values.
fit_options = dict(validation_split=0.1, epochs=50, verbose=2, shuffle=True)
binary_class = model_class.fit(x_train, y_train, **fit_options)
Epoch 1/50 212/212 - 2s - loss: 0.5600 - auc: 0.7987 - val_loss: 0.4688 - val_auc: 0.9319 - 2s/epoch - 10ms/step Epoch 2/50 212/212 - 0s - loss: 0.3687 - auc: 0.9293 - val_loss: 0.3153 - val_auc: 0.9487 - 474ms/epoch - 2ms/step Epoch 3/50 212/212 - 0s - loss: 0.3118 - auc: 0.9441 - val_loss: 0.2858 - val_auc: 0.9519 - 464ms/epoch - 2ms/step Epoch 4/50 212/212 - 1s - loss: 0.2893 - auc: 0.9497 - val_loss: 0.2758 - val_auc: 0.9532 - 644ms/epoch - 3ms/step Epoch 5/50 212/212 - 0s - loss: 0.2836 - auc: 0.9508 - val_loss: 0.2714 - val_auc: 0.9542 - 480ms/epoch - 2ms/step Epoch 6/50 212/212 - 1s - loss: 0.2774 - auc: 0.9530 - val_loss: 0.2656 - val_auc: 0.9555 - 590ms/epoch - 3ms/step Epoch 7/50 212/212 - 1s - loss: 0.2715 - auc: 0.9543 - val_loss: 0.2686 - val_auc: 0.9545 - 685ms/epoch - 3ms/step Epoch 8/50 212/212 - 0s - loss: 0.2720 - auc: 0.9540 - val_loss: 0.2680 - val_auc: 0.9544 - 451ms/epoch - 2ms/step Epoch 9/50 212/212 - 0s - loss: 0.2690 - auc: 0.9552 - val_loss: 0.2651 - val_auc: 0.9554 - 500ms/epoch - 2ms/step Epoch 10/50 212/212 - 0s - loss: 0.2678 - auc: 0.9549 - val_loss: 0.2662 - val_auc: 0.9549 - 425ms/epoch - 2ms/step Epoch 11/50 212/212 - 0s - loss: 0.2657 - auc: 0.9559 - val_loss: 0.2641 - val_auc: 0.9559 - 453ms/epoch - 2ms/step Epoch 12/50 212/212 - 0s - loss: 0.2656 - auc: 0.9560 - val_loss: 0.2633 - val_auc: 0.9563 - 434ms/epoch - 2ms/step Epoch 13/50 212/212 - 0s - loss: 0.2634 - auc: 0.9565 - val_loss: 0.2631 - val_auc: 0.9562 - 445ms/epoch - 2ms/step Epoch 14/50 212/212 - 0s - loss: 0.2616 - auc: 0.9575 - val_loss: 0.2627 - val_auc: 0.9563 - 434ms/epoch - 2ms/step Epoch 15/50 212/212 - 0s - loss: 0.2585 - auc: 0.9582 - val_loss: 0.2644 - val_auc: 0.9561 - 416ms/epoch - 2ms/step Epoch 16/50 212/212 - 0s - loss: 0.2615 - auc: 0.9574 - val_loss: 0.2654 - val_auc: 0.9559 - 437ms/epoch - 2ms/step Epoch 17/50 212/212 - 0s - loss: 0.2570 - auc: 0.9587 - val_loss: 0.2631 - val_auc: 0.9563 - 388ms/epoch - 2ms/step Epoch 18/50 212/212 - 0s - loss: 
0.2566 - auc: 0.9589 - val_loss: 0.2626 - val_auc: 0.9564 - 422ms/epoch - 2ms/step Epoch 19/50 212/212 - 0s - loss: 0.2557 - auc: 0.9595 - val_loss: 0.2645 - val_auc: 0.9557 - 441ms/epoch - 2ms/step Epoch 20/50 212/212 - 0s - loss: 0.2569 - auc: 0.9586 - val_loss: 0.2633 - val_auc: 0.9565 - 418ms/epoch - 2ms/step Epoch 21/50 212/212 - 0s - loss: 0.2540 - auc: 0.9596 - val_loss: 0.2648 - val_auc: 0.9555 - 407ms/epoch - 2ms/step Epoch 22/50 212/212 - 0s - loss: 0.2557 - auc: 0.9590 - val_loss: 0.2647 - val_auc: 0.9558 - 400ms/epoch - 2ms/step Epoch 23/50 212/212 - 0s - loss: 0.2569 - auc: 0.9589 - val_loss: 0.2625 - val_auc: 0.9559 - 419ms/epoch - 2ms/step Epoch 24/50 212/212 - 0s - loss: 0.2579 - auc: 0.9585 - val_loss: 0.2630 - val_auc: 0.9560 - 422ms/epoch - 2ms/step Epoch 25/50 212/212 - 0s - loss: 0.2547 - auc: 0.9592 - val_loss: 0.2634 - val_auc: 0.9558 - 415ms/epoch - 2ms/step Epoch 26/50 212/212 - 0s - loss: 0.2533 - auc: 0.9598 - val_loss: 0.2653 - val_auc: 0.9551 - 426ms/epoch - 2ms/step Epoch 27/50 212/212 - 0s - loss: 0.2547 - auc: 0.9592 - val_loss: 0.2660 - val_auc: 0.9546 - 417ms/epoch - 2ms/step Epoch 28/50 212/212 - 0s - loss: 0.2539 - auc: 0.9598 - val_loss: 0.2669 - val_auc: 0.9549 - 398ms/epoch - 2ms/step Epoch 29/50 212/212 - 0s - loss: 0.2539 - auc: 0.9598 - val_loss: 0.2654 - val_auc: 0.9553 - 424ms/epoch - 2ms/step Epoch 30/50 212/212 - 0s - loss: 0.2526 - auc: 0.9603 - val_loss: 0.2662 - val_auc: 0.9550 - 433ms/epoch - 2ms/step Epoch 31/50 212/212 - 0s - loss: 0.2515 - auc: 0.9604 - val_loss: 0.2654 - val_auc: 0.9557 - 417ms/epoch - 2ms/step Epoch 32/50 212/212 - 0s - loss: 0.2516 - auc: 0.9604 - val_loss: 0.2646 - val_auc: 0.9554 - 403ms/epoch - 2ms/step Epoch 33/50 212/212 - 0s - loss: 0.2555 - auc: 0.9590 - val_loss: 0.2659 - val_auc: 0.9548 - 429ms/epoch - 2ms/step Epoch 34/50 212/212 - 0s - loss: 0.2553 - auc: 0.9591 - val_loss: 0.2669 - val_auc: 0.9545 - 414ms/epoch - 2ms/step Epoch 35/50 212/212 - 0s - loss: 0.2520 - auc: 0.9605 - 
val_loss: 0.2660 - val_auc: 0.9547 - 413ms/epoch - 2ms/step Epoch 36/50 212/212 - 0s - loss: 0.2521 - auc: 0.9602 - val_loss: 0.2657 - val_auc: 0.9547 - 414ms/epoch - 2ms/step Epoch 37/50 212/212 - 0s - loss: 0.2524 - auc: 0.9602 - val_loss: 0.2651 - val_auc: 0.9551 - 463ms/epoch - 2ms/step Epoch 38/50 212/212 - 0s - loss: 0.2518 - auc: 0.9604 - val_loss: 0.2633 - val_auc: 0.9560 - 444ms/epoch - 2ms/step Epoch 39/50 212/212 - 0s - loss: 0.2479 - auc: 0.9618 - val_loss: 0.2636 - val_auc: 0.9555 - 413ms/epoch - 2ms/step Epoch 40/50 212/212 - 0s - loss: 0.2520 - auc: 0.9604 - val_loss: 0.2642 - val_auc: 0.9551 - 438ms/epoch - 2ms/step Epoch 41/50 212/212 - 0s - loss: 0.2492 - auc: 0.9613 - val_loss: 0.2671 - val_auc: 0.9541 - 430ms/epoch - 2ms/step Epoch 42/50 212/212 - 0s - loss: 0.2510 - auc: 0.9605 - val_loss: 0.2661 - val_auc: 0.9545 - 414ms/epoch - 2ms/step Epoch 43/50 212/212 - 0s - loss: 0.2523 - auc: 0.9600 - val_loss: 0.2645 - val_auc: 0.9550 - 419ms/epoch - 2ms/step Epoch 44/50 212/212 - 0s - loss: 0.2504 - auc: 0.9605 - val_loss: 0.2680 - val_auc: 0.9540 - 421ms/epoch - 2ms/step Epoch 45/50 212/212 - 0s - loss: 0.2477 - auc: 0.9612 - val_loss: 0.2651 - val_auc: 0.9541 - 452ms/epoch - 2ms/step Epoch 46/50 212/212 - 0s - loss: 0.2507 - auc: 0.9607 - val_loss: 0.2662 - val_auc: 0.9544 - 440ms/epoch - 2ms/step Epoch 47/50 212/212 - 0s - loss: 0.2503 - auc: 0.9610 - val_loss: 0.2674 - val_auc: 0.9543 - 409ms/epoch - 2ms/step Epoch 48/50 212/212 - 0s - loss: 0.2463 - auc: 0.9619 - val_loss: 0.2689 - val_auc: 0.9540 - 440ms/epoch - 2ms/step Epoch 49/50 212/212 - 0s - loss: 0.2463 - auc: 0.9623 - val_loss: 0.2673 - val_auc: 0.9540 - 429ms/epoch - 2ms/step Epoch 50/50 212/212 - 0s - loss: 0.2461 - auc: 0.9619 - val_loss: 0.2668 - val_auc: 0.9542 - 411ms/epoch - 2ms/step
# Plot training and validation AUC for every epoch of the fit.
for history_key in ('auc', 'val_auc'):
    plt.plot(binary_class.history[history_key])
plt.title('AUC vs Epochs')
plt.xlabel('Epoch')
plt.ylabel('AUC')
plt.legend(['Train', 'Validation'], loc='lower right')
plt.show()
# ROC AUC on the held-out test set; labels and predictions are both two-column arrays
# (one-hot labels vs. the two sigmoid outputs).
test_predictions = model_class.predict(x_test)
metrics.roc_auc_score(y_test, test_predictions)
79/79 [==============================] - 0s 1ms/step
0.9015763512614858
These results are promising. However, since the vast majority of patients with COVID-19 recover from mild cases, it is important to evaluate how well this approach works for patients who had more severe COVID-19 cases. Therefore, I will train a new model with nearly the same architecture (it adds batch normalization after the first layer) on only the patients who entered the complicated phase.
# Prepare the complicated-phase subset the same way as the full dataset:
# drop the recovery columns, separate the target, then split 75/25 with the same seed.
recovery_columns = ['Recovery.phase', 'Symptoms.in.recovery.phase']
complicated_data = complicated_data.drop(columns=recovery_columns)
## pop() removes the target from complicated_data itself, so x_2 (the same object) holds only features.
y_2 = complicated_data.pop('Last.known.patient.status').to_frame()
x_2 = complicated_data
x_train_2, x_test_2, y_train_2, y_test_2 = train_test_split(
    x_2, y_2, test_size=0.25, random_state=42)
x_train_2.head()
| Age.at.diagnosis | Sex | Month.first.diagnosis | Year.first.diagnosis | Uncomplicated.phase | Complicated.phase | Critical.phase | Vasopressors.in.complicated.phase | Vasopressors.in.critical.phase | Invasive.ventilation.in.critical.phase | Superinfection.in.uncomplicated.phase | Superinfection.in.complicated.phase | Superinfection.in.critical.phase | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 5140 | 0.5 | 1 | 0.8181 | 0 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
| 5175 | 0.5 | 1 | 0.909 | 0 | 1 | 1 | 1 | 0 | 1 | 1 | 0 | 0 | 1 |
| 4458 | 0.5 | 0 | 0.0909 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 1 | 1 | 0 |
| 2252 | 0.75 | 1 | 0.0909 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 7481 | 1 | 0 | 0.909 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
# One-hot encode the 0/1 labels into two columns (column 1 = "Dead from COVID-19").
n_classes = 2
y_train_2 = tf.keras.utils.to_categorical(y_train_2.values, num_classes=n_classes)
y_test_2 = tf.keras.utils.to_categorical(y_test_2.values, num_classes=n_classes)
# Re-fit the existing scaler `s` on the complicated-phase training rows;
# from here on `s` holds the statistics of this subset, not of the full dataset.
s.fit(x_train_2)
x_train_2 = s.transform(x_train_2)
x_test_2 = s.transform(x_test_2)
# Start from a clean Keras state and fix the TensorFlow seed.
K.clear_session()
tf.random.set_seed(42)
# Fully connected network: 100 -> 50 -> 10 -> 5 -> 2 units.
n_features_2 = x_train_2.shape[1]
model_class_2 = Sequential()
model_class_2.add(Dense(100, activation='relu', input_shape=(n_features_2,)))
model_class_2.add(BatchNormalization())
model_class_2.add(Dropout(0.1))
# Every wide layer, including the first, is followed by batch normalization and 10% dropout.
for hidden_units in (50, 10):
    model_class_2.add(Dense(hidden_units, activation='relu'))
    model_class_2.add(BatchNormalization())
    model_class_2.add(Dropout(0.1))
model_class_2.add(Dense(5, activation='relu'))
# Two sigmoid outputs, one per one-hot label column, scored with binary cross-entropy.
model_class_2.add(Dense(2, activation='sigmoid'))
model_class_2.compile(loss='binary_crossentropy', optimizer='adamax', metrics=[AUC()])
model_class_2.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 100) 1400
batch_normalization (BatchN (None, 100) 400
ormalization)
dropout (Dropout) (None, 100) 0
dense_1 (Dense) (None, 50) 5050
batch_normalization_1 (Batc (None, 50) 200
hNormalization)
dropout_1 (Dropout) (None, 50) 0
dense_2 (Dense) (None, 10) 510
batch_normalization_2 (Batc (None, 10) 40
hNormalization)
dropout_2 (Dropout) (None, 10) 0
dense_3 (Dense) (None, 5) 55
dense_4 (Dense) (None, 2) 12
=================================================================
Total params: 7,667
Trainable params: 7,347
Non-trainable params: 320
_________________________________________________________________
# Train the complicated-phase model with the same settings as the first fit.
# NOTE: this rebinds `binary_class`, so the first model's training history is no longer
# reachable under that name.
fit_options = dict(validation_split=0.1, epochs=50, verbose=2, shuffle=True)
binary_class = model_class_2.fit(x_train_2, y_train_2, **fit_options)
Epoch 1/50 100/100 - 2s - loss: 0.6812 - auc: 0.6434 - val_loss: 0.6026 - val_auc: 0.8822 - 2s/epoch - 17ms/step Epoch 2/50 100/100 - 0s - loss: 0.5478 - auc: 0.8122 - val_loss: 0.5161 - val_auc: 0.9077 - 283ms/epoch - 3ms/step Epoch 3/50 100/100 - 0s - loss: 0.4878 - auc: 0.8618 - val_loss: 0.4475 - val_auc: 0.9120 - 253ms/epoch - 3ms/step Epoch 4/50 100/100 - 0s - loss: 0.4533 - auc: 0.8775 - val_loss: 0.4061 - val_auc: 0.9160 - 271ms/epoch - 3ms/step Epoch 5/50 100/100 - 0s - loss: 0.4364 - auc: 0.8833 - val_loss: 0.3855 - val_auc: 0.9199 - 246ms/epoch - 2ms/step Epoch 6/50 100/100 - 0s - loss: 0.4259 - auc: 0.8863 - val_loss: 0.3741 - val_auc: 0.9203 - 227ms/epoch - 2ms/step Epoch 7/50 100/100 - 0s - loss: 0.4137 - auc: 0.8921 - val_loss: 0.3713 - val_auc: 0.9201 - 239ms/epoch - 2ms/step Epoch 8/50 100/100 - 0s - loss: 0.4120 - auc: 0.8937 - val_loss: 0.3642 - val_auc: 0.9222 - 250ms/epoch - 3ms/step Epoch 9/50 100/100 - 0s - loss: 0.4067 - auc: 0.8954 - val_loss: 0.3629 - val_auc: 0.9224 - 246ms/epoch - 2ms/step Epoch 10/50 100/100 - 0s - loss: 0.3985 - auc: 0.8995 - val_loss: 0.3569 - val_auc: 0.9247 - 265ms/epoch - 3ms/step Epoch 11/50 100/100 - 0s - loss: 0.3991 - auc: 0.8995 - val_loss: 0.3526 - val_auc: 0.9253 - 240ms/epoch - 2ms/step Epoch 12/50 100/100 - 0s - loss: 0.4019 - auc: 0.8985 - val_loss: 0.3553 - val_auc: 0.9245 - 225ms/epoch - 2ms/step Epoch 13/50 100/100 - 0s - loss: 0.3970 - auc: 0.8998 - val_loss: 0.3534 - val_auc: 0.9260 - 229ms/epoch - 2ms/step Epoch 14/50 100/100 - 0s - loss: 0.3955 - auc: 0.9012 - val_loss: 0.3493 - val_auc: 0.9274 - 244ms/epoch - 2ms/step Epoch 15/50 100/100 - 0s - loss: 0.3916 - auc: 0.9032 - val_loss: 0.3491 - val_auc: 0.9279 - 232ms/epoch - 2ms/step Epoch 16/50 100/100 - 0s - loss: 0.3852 - auc: 0.9061 - val_loss: 0.3507 - val_auc: 0.9264 - 227ms/epoch - 2ms/step Epoch 17/50 100/100 - 0s - loss: 0.3918 - auc: 0.9030 - val_loss: 0.3492 - val_auc: 0.9275 - 249ms/epoch - 2ms/step Epoch 18/50 100/100 - 0s - loss: 
0.3870 - auc: 0.9050 - val_loss: 0.3514 - val_auc: 0.9260 - 251ms/epoch - 3ms/step Epoch 19/50 100/100 - 0s - loss: 0.3870 - auc: 0.9057 - val_loss: 0.3548 - val_auc: 0.9247 - 231ms/epoch - 2ms/step Epoch 20/50 100/100 - 0s - loss: 0.3846 - auc: 0.9069 - val_loss: 0.3546 - val_auc: 0.9246 - 232ms/epoch - 2ms/step Epoch 21/50 100/100 - 0s - loss: 0.3840 - auc: 0.9066 - val_loss: 0.3544 - val_auc: 0.9248 - 247ms/epoch - 2ms/step Epoch 22/50 100/100 - 0s - loss: 0.3914 - auc: 0.9044 - val_loss: 0.3542 - val_auc: 0.9249 - 236ms/epoch - 2ms/step Epoch 23/50 100/100 - 0s - loss: 0.3873 - auc: 0.9056 - val_loss: 0.3533 - val_auc: 0.9251 - 230ms/epoch - 2ms/step Epoch 24/50 100/100 - 0s - loss: 0.3899 - auc: 0.9049 - val_loss: 0.3536 - val_auc: 0.9242 - 247ms/epoch - 2ms/step Epoch 25/50 100/100 - 0s - loss: 0.3797 - auc: 0.9097 - val_loss: 0.3568 - val_auc: 0.9226 - 268ms/epoch - 3ms/step Epoch 26/50 100/100 - 0s - loss: 0.3792 - auc: 0.9096 - val_loss: 0.3558 - val_auc: 0.9234 - 241ms/epoch - 2ms/step Epoch 27/50 100/100 - 0s - loss: 0.3848 - auc: 0.9069 - val_loss: 0.3545 - val_auc: 0.9242 - 260ms/epoch - 3ms/step Epoch 28/50 100/100 - 0s - loss: 0.3831 - auc: 0.9076 - val_loss: 0.3534 - val_auc: 0.9244 - 234ms/epoch - 2ms/step Epoch 29/50 100/100 - 0s - loss: 0.3841 - auc: 0.9066 - val_loss: 0.3550 - val_auc: 0.9237 - 245ms/epoch - 2ms/step Epoch 30/50 100/100 - 0s - loss: 0.3815 - auc: 0.9079 - val_loss: 0.3561 - val_auc: 0.9232 - 242ms/epoch - 2ms/step Epoch 31/50 100/100 - 0s - loss: 0.3797 - auc: 0.9087 - val_loss: 0.3584 - val_auc: 0.9217 - 224ms/epoch - 2ms/step Epoch 32/50 100/100 - 0s - loss: 0.3811 - auc: 0.9095 - val_loss: 0.3562 - val_auc: 0.9222 - 227ms/epoch - 2ms/step Epoch 33/50 100/100 - 0s - loss: 0.3777 - auc: 0.9104 - val_loss: 0.3555 - val_auc: 0.9221 - 238ms/epoch - 2ms/step Epoch 34/50 100/100 - 0s - loss: 0.3726 - auc: 0.9130 - val_loss: 0.3547 - val_auc: 0.9230 - 271ms/epoch - 3ms/step Epoch 35/50 100/100 - 0s - loss: 0.3716 - auc: 0.9130 - 
val_loss: 0.3531 - val_auc: 0.9238 - 243ms/epoch - 2ms/step Epoch 36/50 100/100 - 0s - loss: 0.3764 - auc: 0.9104 - val_loss: 0.3535 - val_auc: 0.9239 - 249ms/epoch - 2ms/step Epoch 37/50 100/100 - 0s - loss: 0.3794 - auc: 0.9093 - val_loss: 0.3534 - val_auc: 0.9243 - 235ms/epoch - 2ms/step Epoch 38/50 100/100 - 0s - loss: 0.3717 - auc: 0.9129 - val_loss: 0.3542 - val_auc: 0.9229 - 259ms/epoch - 3ms/step Epoch 39/50 100/100 - 0s - loss: 0.3719 - auc: 0.9130 - val_loss: 0.3559 - val_auc: 0.9220 - 244ms/epoch - 2ms/step Epoch 40/50 100/100 - 0s - loss: 0.3731 - auc: 0.9122 - val_loss: 0.3543 - val_auc: 0.9228 - 228ms/epoch - 2ms/step Epoch 41/50 100/100 - 0s - loss: 0.3738 - auc: 0.9127 - val_loss: 0.3532 - val_auc: 0.9234 - 222ms/epoch - 2ms/step Epoch 42/50 100/100 - 0s - loss: 0.3721 - auc: 0.9127 - val_loss: 0.3539 - val_auc: 0.9234 - 251ms/epoch - 3ms/step Epoch 43/50 100/100 - 0s - loss: 0.3783 - auc: 0.9095 - val_loss: 0.3510 - val_auc: 0.9249 - 263ms/epoch - 3ms/step Epoch 44/50 100/100 - 0s - loss: 0.3768 - auc: 0.9100 - val_loss: 0.3535 - val_auc: 0.9244 - 261ms/epoch - 3ms/step Epoch 45/50 100/100 - 0s - loss: 0.3776 - auc: 0.9102 - val_loss: 0.3511 - val_auc: 0.9259 - 232ms/epoch - 2ms/step Epoch 46/50 100/100 - 0s - loss: 0.3669 - auc: 0.9153 - val_loss: 0.3509 - val_auc: 0.9252 - 383ms/epoch - 4ms/step Epoch 47/50 100/100 - 0s - loss: 0.3776 - auc: 0.9106 - val_loss: 0.3536 - val_auc: 0.9242 - 238ms/epoch - 2ms/step Epoch 48/50 100/100 - 0s - loss: 0.3709 - auc: 0.9136 - val_loss: 0.3540 - val_auc: 0.9235 - 242ms/epoch - 2ms/step Epoch 49/50 100/100 - 0s - loss: 0.3640 - auc: 0.9164 - val_loss: 0.3572 - val_auc: 0.9217 - 244ms/epoch - 2ms/step Epoch 50/50 100/100 - 0s - loss: 0.3815 - auc: 0.9081 - val_loss: 0.3578 - val_auc: 0.9207 - 238ms/epoch - 2ms/step
# Draw the per-epoch AUC curves stored in the fit history: the training
# series first, then the validation series, so the legend order matches.
# NOTE(review): `binary_class` is presumably the History object returned by
# the preceding `fit` call -- confirm against the cell above.
for history_key in ('auc', 'val_auc'):
    plt.plot(binary_class.history[history_key])
plt.xlabel('Epoch')
plt.ylabel('AUC')
plt.title('AUC vs Epochs')
plt.legend(['Train', 'Validation'], loc='lower right')
plt.show()
# Compute ROC AUC for the second classifier from its predictions on
# x_test_2 against the labels in y_test_2. The bare expression on the last
# line is what the notebook echoes as the cell result.
test_scores_2 = model_class_2.predict(x_test_2)
metrics.roc_auc_score(y_test_2, test_scores_2)
37/37 [==============================] - 0s 1ms/step
0.8311267130341795
The training AUC is a bit lower for the complicated dataset, but performance is still good. COVID-19 is an unpredictable virus, so it is unreasonable to expect perfect performance. Although the model seems to perform well, it is hard to understand exactly how it works, so I used Shapley values and surrogate decision trees to interpret it.
# Convert the feature matrix to a float32 NumPy array, then pick 100
# distinct rows at random (sampling without replacement) to act as the
# background sample handed to the SHAP explainer.
# NOTE(review): no random seed is set here, so the rows drawn differ from
# run to run.
x = np.asarray(x).astype(np.float32)
row_count = x.shape[0]
background_rows = np.random.choice(row_count, size=100, replace=False)
background = x[background_rows]
import shap
import warnings
# Suppress every warning from this point on (the filter is global, so it
# also applies to all later cells).
warnings.filterwarnings(action='ignore')

# Register SHAP's passthrough handler for the "AddV2" op before the
# explainer is built.
# NOTE(review): this touches a private SHAP module; presumably it works
# around a missing handler in the installed SHAP version -- confirm.
deep_tf = shap.explainers._deep.deep_tf
deep_tf.op_handlers["AddV2"] = deep_tf.passthrough

# Explain `model_class` against the sampled background rows, compute SHAP
# values for x_train, and show the summary plot.
explainer = shap.DeepExplainer(model_class, background)
shap_values = explainer.shap_values(x_train)
expected_value = explainer.expected_value
shap.summary_plot(shap_values, features=x_train, title="Summary Plot")
The numbers correspond to the variables as follows:
Feature 0: Age
Feature 1: Sex
Feature 2: Month
Feature 3: Year
Feature 4: Uncomplicated phase
Feature 5: Complicated phase
Feature 6: Critical phase
Feature 7: Recovery phase
Feature 8: Vasopressors in complicated phase
Feature 9: Vasopressors in critical phase
Feature 10: Invasive ventilation in critical phase
Feature 11: Superinfection in uncomplicated phase
Feature 12: Superinfection in complicated phase
Feature 13: Superinfection in critical phase
Feature 14: Symptoms in recovery phase
According to the SHAP plot, the three most heavily weighted variables are the patient's age (feature 0), whether the patient entered the critical phase (feature 6), and whether the patient needed vasopressors in the critical phase (feature 9).
Large Shapley values do not necessarily mean that these features are the most important for predicting the outcome. Rather, they show that these features have the strongest impact on the model's output; they may even be diminishing its performance. However, these findings make sense given real-world observations. I will also use a surrogate decision tree to evaluate feature importance.
import dalex as dx
# Wrap the trained network in a dalex Explainer over the full feature
# matrix `x` and target `y`. This rebinds `explainer`, replacing the SHAP
# explainer created earlier.
# NOTE(review): the recorded output for this cell reports that the default
# residual function failed with "'float' object has no attribute 'astype'"
# and that `y` was passed as a pandas.DataFrame. Passing `y` as a 1-D
# numeric array may avoid this -- TODO confirm.
explainer = dx.Explainer(
    model=model_class,
    data=x,
    y=y,
    label='Status',
)
Preparation of a new explainer is initiated -> data : numpy.ndarray converted to pandas.DataFrame. Columns are set as string numbers. -> data : 10020 rows 13 cols -> target variable : Parameter 'y' was a pandas.DataFrame. Converted to a numpy.ndarray. -> target variable : 10020 values -> model_class : keras.engine.sequential.Sequential (default) -> label : Status -> predict function : <function yhat_tf_classification at 0x000001800CE43B80> will be used (default) 1/1 [==============================] - 0s 31ms/step -> predict function : Accepts pandas.DataFrame and numpy.ndarray. 314/314 [==============================] - 1s 2ms/step -> predicted values : min = 0.0191, mean = 0.211, max = 0.705 -> model type : classification will be used (default) -> residual function : difference between y and yhat (default) 314/314 [==============================] - 1s 2ms/step -> residuals : 'residual_function' returns an Error when executed: 'float' object has no attribute 'astype' -> model_info : package keras A new explainer has been created!
# Compute dalex's model-level variable importance for the network and plot
# the result.
## These seem to correspond with the Shapley values.
variable_importance = explainer.model_parts()
variable_importance.plot()
32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 3ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 3ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] 
- 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 3ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 1ms/step 32/32 
[==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 953us/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 
0s 1ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 862us/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 997us/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 1ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 2ms/step 32/32 [==============================] - 0s 1ms/step
# Fit a small surrogate model to the explainer, limited to at most three
# variables and a depth of three so that it stays readable. The bare
# `performance` expression is echoed by the notebook as the cell result.
surrogate_model = explainer.model_surrogate(
    max_depth=3,
    max_vars=3,
)
surrogate_model.performance
16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 940us/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 
[==============================] - 0s 3ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 3ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 709us/step 16/16 [==============================] - 0s 3ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 
0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 690us/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 795us/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 3ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 806us/step 16/16 [==============================] - 0s 3ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 3ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 876us/step 16/16 
[==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 689us/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 759us/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 3ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 3ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 917us/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 3ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] 
- 0s 2ms/step 16/16 [==============================] - 0s 759us/step 16/16 [==============================] - 0s 992us/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 780us/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 
[==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 794us/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 2ms/step 16/16 [==============================] - 0s 1ms/step 16/16 [==============================] - 
0s 2ms/step
| | recall | precision | f1 | accuracy | auc |
|---|---|---|---|---|---|
| DecisionTreeClassifier | 0.767296 | 0.627787 | 0.690566 | 0.967265 | 0.977021 |
# Draw the fitted surrogate model (the recorded performance table above
# identifies it as a DecisionTreeClassifier).
surrogate_model.plot()
The decision tree surrogate model uses features 9 (whether vasopressors are needed in the critical phase), 0 (age), and 6 (whether the patient entered the critical phase) to determine if a patient will survive COVID-19. An older patient who needs vasopressors in the critical phase is the most likely to die of the virus according to this model. In all other cases, the model predicts that the individual will likely recover from the virus.
The analysis above corresponds with real-world observations. Risk of dying from COVID-19 increases substantially with age, and patients who require medical treatment during the critical phase (e.g. vasopressors) are less likely to survive.
There is currently no surefire way to predict whether a person will die from COVID-19, since the virus itself is very unpredictable and there are many other underlying risk factors at play. However, this model performs well given the limited information it is based on. Adding more data (e.g., socioeconomic factors like access to healthcare, health conditions like asthma or heart disease, and the variant of COVID-19 the patient was infected with) would likely improve its performance. Again, it is unreasonable to expect 100% accuracy, but a generalizable, highly accurate model could be immensely beneficial for public health.